
include("../src/qmcmary.jl")
using ..qmcmary

using LinearAlgebra


"""
不使用稳定
"""
function rawgbar(allbmats, idx)
    Nx = size(allbmats[1])[1]
    Rs = I(Nx)
    for i in Base.OneTo(idx)
        Rs = allbmats[i]*Rs
    end
    Ls = Diagonal(ones(Nx))
    for i in length(allbmats):-1:(idx+1)
        Ls = Ls*allbmats[i]
    end
    return inv(Diagonal(ones(Nx)) + inv(Ls)*inv(Rs))
end


"""
rawgbar对比
"""
function bbar1()
    #创建ss
    Nx = 2
    #ehk = rand(ComplexF64, Nx, Nx)
    #ehk = Hermitian(ehk+adjoint(ehk))
    hk = lattice_chain(Nx, -1.0+0.0im)
    Ui = ones(Nx)
    the = rand(Nx)*pi
    nx = zeros(Nx)
    ny = zeros(Nx)
    nz = ones(Nx)
    #
    Nt = 20
    Ng = 5
    bidxs = Vector{Int}(undef, Ng)
    hscfg = Matrix{Int}(undef, Nt, Nx)
    for bi in Base.OneTo(Nt)
        cfg = 2(round.(rand(Nx)) .- 0.5)
        hscfg[bi, :] .= cfg
    end
    sp = default_splitting(Nt, hk, Ui; Z2=true)
    sslen, bmats, allbmats, ss = initialize_SS(Nt, Ng, sp, hscfg, nx, ny, nz)
    #
    gs1 = gbar_scratch(ss)
    gs2 = rawgbar(allbmats, 0)
    println(isapprox.(gs1, gs2, atol=1e-6))
    #
    gs19 = down_prog_gbar(gs1, allbmats[20])
    gs192 = rawgbar(allbmats, 19)
    println(gs19, gs192)
    println(isapprox.(gs19, gs192, atol=1e-6))
    #
    scrollR2L(ss, bmats)
    gs1 = gbar_scratch(ss)
    gs2 = rawgbar(allbmats, 15)
    println(gs1, gs2)
    println(isapprox.(gs1, gs2, atol=1e-6))
    gs14 = down_prog_gbar(gs1, allbmats[15])
    gs13 = down_prog_gbar(gs14, allbmats[14])
    gs132 = rawgbar(allbmats, 13)
    println(gs13, gs132)
    println(isapprox.(gs13, gs132, atol=1e-6))
    #
    scrollL2R(ss, bmats)
    #
    gs1 = rawgbar(allbmats, 0)
    gs2 = rawgbar(allbmats, 1)
    #G_l 用来调整 B_l+1
    bbar1 = bbar_from_gbar(gs1, allbmats[1])
    #对照
    bbar2 = bbar_calc(ss, allbmats)
    println("bbar ", bbar1)
    println(bbar2[1, :, :])
    #
    #注意需要优化的是 - d P_{s}
    sig1 = det(I(4) + prod(reverse(allbmats)))
    println(sig1)
    println(-log(sig1))
    allbmats[1][1, 2] += 0.01
    sig2 = det(I(4) + prod(reverse(allbmats)))
    println(sig2)
    println(-log(sig2))
    println(-log(sig2) + log(sig1))
end


"""
    gbar1()

Print, group by group, how the propagated value compares with a fresh
`gbar_scratch(ss)`.

For each of the `sslen` groups the function prints the leading 3x3 corner of
`gbar_scratch(ss)`, prints the largest absolute difference to the value obtained
by applying `down_prog_gbar` `Ng` times in the previous group (together with
`findfirst(ss.L)`), and then calls `scrollR2L` with the single entry
`ss.B[sslen-grpi+1]`.

Nothing is asserted; the results are only printed.
"""
function gbar1()
    Nx = 72
    ehk = lattice_chain(Nx, -1)
    Ui = 6*ones(Nx)
    the = rand(Nx)*pi
    nx = zeros(Nx)
    ny = sin.(the)
    nz = cos.(the)
    #
    Nt = 200
    Ng = 5
    # Random configuration with entries -1 or +1, one row per `bi`.
    hscfg = Matrix{Int}(undef, Nt, Nx)
    for bi in Base.OneTo(Nt)
        cfg = 2(round.(rand(Nx)) .- 0.5)
        hscfg[bi, :] .= cfg
    end
    # NOTE(review): bbar1 in this file calls initialize_SS with a
    # `default_splitting` result instead of `ehk`/`Ui` -- confirm that this
    # argument list is still supported.
    sslen, _, allbmats, ss = initialize_SS(Nt, Ng, ehk, hscfg, Ui, nx, ny, nz)
    #
    # Value propagated through the previous group; `nothing` before the first group.
    prevgss = nothing
    for grpi in Base.OneTo(sslen)
        gss = gbar_scratch(ss)
        println("gss ", gss[1:3, 1:3])
        if prevgss !== nothing
            println(maximum(abs.(prevgss-gss)), " ", findfirst(ss.L))
        end
        # Apply the Ng matrices of this group, counting indices down from the top.
        for ni in Base.OneTo(Ng)
            gi = Nt - Ng*(grpi-1) - ni + 1
            gss = down_prog_gbar(gss, allbmats[gi])
        end
        prevgss = gss
        scrollR2L(ss, [ss.B[sslen-grpi+1]])
    end
end


# Script driver: the bbar1 checks run as soon as this file is executed or included.
bbar1()
# The gbar1 run is currently disabled; uncomment the next line to enable it.
#gbar1()
